import pandas as p
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from AnalyzeHistoricalNHTS import getVehicleDataByYear, getPersonDataByYear, getHouseholdDataByYear, get_state_safety_inspection_df_by_year, get_state_safety_inspection_dict, load_stata_data
from tqdm import tqdm, trange
from copy import deepcopy
# import geopandas as gp
import statsmodels as sm
from statsmodels.stats.weightstats import DescrStatsW
# p.options.mode.copy_on_write = True

# def load_map_data():
#     file = '/Users/connorforsythe/Library/CloudStorage/Box-Box/CMU/Data/Map Data/cb_2018_us_state_500k 2/cb_2018_us_state_500k.shp'
#     map = gp.read_file(file)
#     return map


def load_household_data(year=2017, small=False, window=2):
    """Load household-level data for ``year`` and attach derived columns.

    Steps performed below:
      * fetch households via ``getHouseholdDataByYear(year, small)``;
      * outer-merge the per-state ``TreatmentYear`` table on ``'State Code'``;
      * add an ``'Income'`` column via ``apply_income_mapping``;
      * attach per-household means of the person columns ``R_AGE_IMP``,
        ``Pre-Weighted Age`` and ``WTPERFIN``, plus ``'Age'`` (a copy of the
        mean ``R_AGE_IMP``);
      * for 2009 only, attach per-household totals ``'YOUNGCHILD'`` (persons with
        ``R_AGE_IMP`` between 0 and 4 inclusive) and ``'count'`` (person records).

    Args:
        year: Survey year passed to the household and person loaders.
        small: Passed through to the loaders (presumably selects a reduced
            sample -- confirm in AnalyzeHistoricalNHTS).
        window: Currently unused; only referenced by the commented-out
            treatment-year filter below.

    Returns:
        The merged household DataFrame.
    """
    hh_data = getHouseholdDataByYear(year, small)

    # NOTE(review): how='outer' also keeps states that exist only in the treatment-year
    # table, producing rows with missing household fields (e.g. NaN HOUSEID). Confirm intended.
    hh_data = hh_data.merge(get_state_treatment_year_dict(), on='State Code', how='outer')
    # treatment_years_to_drop = list(range(year-window, year+window+1))
    # hh_data = hh_data.loc[~(hh_data.loc[:, 'TreatmentYear'].isin(treatment_years_to_drop)), :]

    hh_data = apply_income_mapping(hh_data)

    person_data = load_person_data(year, small)

    # When R_AGE_IMP is missing, it is derived from R_AGE, but only for 2009; other years
    # are assumed to already provide R_AGE_IMP (otherwise the next statement raises KeyError).
    # R_AGE == 92 is recoded to 89 -- presumably a top-code/sentinel convention; confirm
    # against the 2009 codebook.
    if('R_AGE_IMP' not in person_data.columns and year==2009):
        person_data.loc[:, 'R_AGE_IMP'] = person_data.loc[:, 'R_AGE']
        person_data.loc[person_data.loc[:, 'R_AGE']==92, 'R_AGE_IMP'] = 89



    person_data.loc[:, 'Pre-Weighted Age'] = person_data.loc[:, 'R_AGE_IMP']*person_data.loc[:, 'WTPERFIN']
    person_data = person_data.loc[:, ['HOUSEID', 'R_AGE_IMP', 'Pre-Weighted Age', 'WTPERFIN']]
    # Per-household (unweighted) means of the selected person columns.
    grouped_person_data = person_data.groupby('HOUSEID', as_index=False).mean()



    # NOTE(review): 'Age' is the unweighted household mean of R_AGE_IMP; 'Pre-Weighted Age'
    # is averaged above but not used to build a weight-adjusted age. If negative missing-value
    # codes exist in R_AGE_IMP they are averaged in as well -- confirm.
    grouped_person_data.loc[:, 'Age'] = grouped_person_data.loc[:, 'R_AGE_IMP']
    hh_data = hh_data.merge(grouped_person_data, on='HOUSEID', how='left')

    if (year == 2009):
        # person_data here is the 4-column subset taken above with .loc; pandas may emit
        # SettingWithCopyWarning on these column assignments.
        person_data.loc[:, 'YOUNGCHILD'] = 0
        person_data.loc[:, 'count'] = 1
        person_data.loc[
            (person_data.loc[:, 'R_AGE_IMP'] >= 0) & (person_data.loc[:, 'R_AGE_IMP'] <= 4), 'YOUNGCHILD'] = 1
        # Summing the 0/1 flags per household gives young-child and person counts.
        grouped_person_data_sum = person_data.groupby('HOUSEID', as_index=False).sum()
        hh_data = hh_data.merge(grouped_person_data_sum.loc[:, ['HOUSEID', 'YOUNGCHILD', 'count']], on='HOUSEID', how='left')

    return hh_data


def load_person_data(year=2017, small=False):
    """Fetch the person-level survey table for ``year``.

    Thin wrapper around ``getPersonDataByYear``; ``small`` is passed straight through.
    """
    return getPersonDataByYear(year, small)

def apply_income_mapping(data):
    """Add an 'Income' column decoded from the categorical 'HHFAMINC' code.

    Codes 1-11 map to the values listed below (presumably bracket midpoints in
    thousands of dollars -- confirm against the survey codebook). The negative
    codes -9, -8 and -7 map to -1, and any other code becomes NaN.
    ``data`` is modified in place and also returned.
    """
    bracket_values = (5, 12.5, 20, 30, 42.5, 62.5, 87.5, 112.5, 137.5, 175, 200)
    code_to_income = {code: value for code, value in enumerate(bracket_values, start=1)}
    # Negative codes are treated as missing responses and flagged with -1.
    code_to_income.update(dict.fromkeys((-9, -8, -7), -1))

    data.loc[:, 'Income'] = data['HHFAMINC'].map(code_to_income)
    return data

def get_states_by_policy_status(data):
    """Split the states in ``data`` by their safety-inspection status.

    Args:
        data: DataFrame with 'State Code' and 'HasSafety' columns.

    Returns:
        ``{'With Safety Inspections': [...], 'Without Safety Inspections': [...]}``.
        Each list holds state codes in order of first appearance. A state is
        listed under 'With' if any of its rows has a truthy 'HasSafety' value
        (note that NaN is truthy in Python), and under 'Without' if any row has a
        falsy one, so a state with mixed values appears in both lists.
    """
    r = {'With Safety Inspections': [], 'Without Safety Inspections': []}
    # Only distinct (state, status) pairs matter. Deduplicating first (first-occurrence
    # order is kept) avoids a Python-level loop over every household row.
    pairs = data.loc[:, ['State Code', 'HasSafety']].drop_duplicates()
    for state, has_safety in zip(pairs['State Code'], pairs['HasSafety']):
        key = 'With Safety Inspections' if has_safety else 'Without Safety Inspections'
        # Distinct truthy statuses (e.g. 1 and 2) still map to the same key.
        if state not in r[key]:
            r[key].append(state)
    return r
def get_state_treatment_year_dict(df=True):
    """Return each state's 'EndedSafety' year from the Stata panel.

    Only rows with ``Year == 1995`` are read (presumably 'EndedSafety' is constant
    across years for a state -- confirm). If a state has several such rows, the
    last one wins. States whose value is NaN are dropped and the rest are cast to int.

    Args:
        df: When truthy, return a DataFrame indexed by state with columns
            'TreatmentYear' and 'State Code'; otherwise return a plain
            ``{state_code: year}`` dict.
    """
    stata = load_stata_data()
    base_year_rows = stata.loc[stata.loc[:, "Year"] == 1995]

    # Building the dict first means duplicate states keep their last value before NaN filtering.
    latest_by_state = dict(zip(base_year_rows.loc[:, 'State Code'], base_year_rows.loc[:, 'EndedSafety']))
    treatment_years = {
        state: int(ended)
        for state, ended in latest_by_state.items()
        if not np.isnan(ended)
    }

    if not df:
        return treatment_years

    frame = p.DataFrame.from_dict(treatment_years, 'index', columns=['TreatmentYear'])
    frame.loc[:, 'State Code'] = frame.index
    return frame

def compare_across_states(raw_data, group_1, group_2, columns):
    """Measure how balanced two groups of states are on the given columns.

    Rows of ``raw_data`` whose 'State Code' is in ``group_1`` or ``group_2`` are
    kept. Rows from ``group_2`` states get ``group`` 1 and all other kept rows get 0,
    so a state listed in both groups counts as group 1. Each column is min-max
    scaled over the kept rows, and the 'WTHHFIN'-weighted means of the raw and
    scaled values are computed per group. ``raw_data`` itself is not modified.

    Args:
        raw_data: DataFrame containing 'State Code', 'WTHHFIN' and every name in ``columns``.
        group_1: State codes forming group 0.
        group_2: State codes forming group 1.
        columns: Columns to compare.

    Returns:
        ``(r, grouped_data)``:
          * ``r[col]`` is the squared difference between the largest and smallest
            per-group weighted mean of the scaled column;
          * ``r['total']`` is the sum of those values;
          * ``grouped_data`` has one row per group. It holds the weighted sums
            (``'<col>-weighted-raw'``, ``'<col>-weighted-scaled-raw'``) and the
            weighted means (``'<col>-weighted'``, ``'<col>-weighted-scaled'``).
    """
    full_group = list(group_1) + list(group_2)

    # Explicit copy: the boolean-mask selection is otherwise flagged by pandas as a
    # possible view, so every column assignment below raised SettingWithCopyWarning.
    data = raw_data.loc[raw_data.loc[:, 'State Code'].isin(full_group), :].copy()
    data.loc[:, 'group'] = 0
    data.loc[data.loc[:, 'State Code'].isin(group_2), 'group'] = 1

    weights = data.loc[:, 'WTHHFIN']
    cols_to_keep = ['group', 'WTHHFIN']
    for col in columns:
        temp_col = f'{col}-weighted'
        temp_col_scaled = f'{col}-weighted-scaled'
        lo, hi = data[col].min(), data[col].max()
        # A column that is constant over the selected rows cannot separate the groups.
        # Dividing by its zero range produced NaN, which poisoned r['total'] and left the
        # greedy search with no grouping. Use a unit divisor so it scales to 0 instead.
        span = hi - lo
        divisor = span if span != 0 else 1
        cols_to_keep.append(temp_col)
        cols_to_keep.append(temp_col_scaled)
        data.loc[:, temp_col] = data.loc[:, col] * weights
        data.loc[:, temp_col_scaled] = ((data.loc[:, col] - lo) / divisor) * weights

    grouped_data = data.loc[:, cols_to_keep].groupby('group', as_index=False).sum()
    r = {}
    total_sum_sq_diff = 0
    for col in columns:
        temp_col = f'{col}-weighted'
        temp_col_scaled = f'{col}-weighted-scaled'

        # Preserve the weighted sums before turning the columns into weighted means.
        grouped_data.loc[:, temp_col + '-raw'] = grouped_data.loc[:, temp_col]
        grouped_data.loc[:, temp_col_scaled + '-raw'] = grouped_data.loc[:, temp_col_scaled]
        grouped_data.loc[:, temp_col] = grouped_data.loc[:, temp_col] / grouped_data.loc[:, 'WTHHFIN']
        grouped_data.loc[:, temp_col_scaled] = grouped_data.loc[:, temp_col_scaled] / grouped_data.loc[:, 'WTHHFIN']
        temp_sq_diff = (np.max(grouped_data.loc[:, temp_col_scaled]) - np.min(grouped_data.loc[:, temp_col_scaled])) ** 2
        r[col] = temp_sq_diff
        total_sum_sq_diff += temp_sq_diff

    r['total'] = total_sum_sq_diff

    return r, grouped_data

def get_potential_states(states_used, states_possible):
    """Return the entries of ``states_possible`` not already in ``states_used``.

    Order (and any duplicates) of ``states_possible`` is preserved.
    """
    return [candidate for candidate in states_possible if candidate not in states_used]

def calculate_all_outcomes(data, columns, treat_grouping, previous_grouping, possible_states_to_add):
    """Score every one-state extension of ``previous_grouping``.

    For each candidate state, ``previous_grouping + [state]`` is compared against
    ``treat_grouping`` with ``compare_across_states``.

    Returns:
        ``{state: {'init_group', 'new_group', 'full_res', 'grouped_data', 'obj_val'}}``,
        where 'obj_val' is the comparison's 'total'. Grouping lists are independent
        (deep) copies, so callers may mutate them safely.
    """
    outcomes = {}
    for candidate in possible_states_to_add:
        candidate_grouping = deepcopy(previous_grouping) + [candidate]
        comparison, grouped = compare_across_states(data, treat_grouping, candidate_grouping, columns)
        outcomes[candidate] = {
            'init_group': deepcopy(previous_grouping),
            'new_group': deepcopy(candidate_grouping),
            'full_res': comparison,
            'grouped_data': grouped,
            'obj_val': comparison['total'],
        }
    return outcomes

def determine_comparison_groups(data, columns, max_comparison_group_size=None):
    """Greedily grow a comparison group of states without safety inspections.

    Starting from an empty group, each step adds the single candidate state that
    minimises ``compare_across_states(...)['total']`` against the treated states.

    Args:
        data: Household DataFrame (needs 'State Code', 'HasSafety', 'WTHHFIN'
            and ``columns``).
        columns: Columns to balance on.
        max_comparison_group_size: Maximum number of greedy steps. Defaults to (and
            is capped at) the number of candidate states.

    Returns:
        ``{group_size: {'full_comparison', 'grouping', 'obj_val', 'res'}}`` for each
        completed step. The search stops early if a step finds no candidate with an
        objective below infinity (e.g. every objective is NaN).
    """
    base_groups = get_states_by_policy_status(data)
    treat_group = base_groups['With Safety Inspections']
    potential_comparison_states = base_groups['Without Safety Inspections']

    # Each step adds one state, so more steps than candidates would leave a step
    # with no grouping (None) and crash the following step.
    if max_comparison_group_size is None:
        max_comparison_group_size = len(potential_comparison_states)
    else:
        max_comparison_group_size = min(max_comparison_group_size, len(potential_comparison_states))

    best_groupings = {}
    previous_grouping = []
    for num_states in trange(1, max_comparison_group_size + 1):
        temp_possible_states = get_potential_states(previous_grouping, potential_comparison_states)
        temp_full_outcomes = calculate_all_outcomes(data, columns, treat_group, previous_grouping, temp_possible_states)

        temp_best_obj_val = np.inf
        temp_best_group = None
        temp_best_group_res = None
        for res_dict in temp_full_outcomes.values():
            # Strict '<' keeps the first of tied candidates and never selects NaN.
            if res_dict['obj_val'] < temp_best_obj_val:
                temp_best_obj_val = res_dict['obj_val']
                temp_best_group = res_dict['new_group']
                temp_best_group_res = res_dict

        if temp_best_group is None:
            # Nothing to extend: stop instead of failing on get_potential_states(None, ...).
            break

        best_groupings[num_states] = {
            'full_comparison': temp_full_outcomes,
            'grouping': temp_best_group,
            'obj_val': temp_best_obj_val,
            'res': temp_best_group_res,
        }
        previous_grouping = temp_best_group

    return best_groupings


def get_best_comparison_group(data, columns, max_comparison_group_size=None):
    """Pick the comparison-group size with the lowest balance objective.

    ``data`` is deep-copied and rows with a negative or missing value in any of
    ``columns`` are dropped before the greedy search runs.

    Returns:
        A dict with:
          * 'full_comparison': every greedy step from ``determine_comparison_groups``;
          * 'best_num': the winning group size (-1 if none has a finite objective);
          * 'best_res': that step's result dict, or None;
          * 'best_val': its objective, or ``np.inf``;
          * 'data': the cleaned data.
    """
    clean_data = remove_bad_entries(deepcopy(data), columns)
    candidates = determine_comparison_groups(clean_data, columns, max_comparison_group_size)

    best_num, best_res, best_val = -1, None, np.inf
    for group_size, candidate in candidates.items():
        # Strict '<' keeps the smallest group among ties and ignores NaN objectives.
        if candidate['obj_val'] < best_val:
            best_num, best_res, best_val = group_size, candidate, candidate['obj_val']

    return {
        'full_comparison': candidates,
        'best_num': best_num,
        'best_res': best_res,
        'best_val': best_val,
        'data': clean_data,
    }

# def plot_comparison_group(control_group = ['CA', 'OR'], x = 5, y=None, large=True):
#     base_groups = get_states_by_policy_status(load_household_data())
#
#     group_1 = base_groups['With Safety Inspections']
#     group_2 = control_group
#
#     num_treated = len(group_1)
#     num_control = len(group_2)
#
#     if(large):
#         title_large_val = 'Large'
#         save_large_val = 'large'
#     else:
#         title_large_val = 'Small'
#         save_large_val = 'small'
#
#     if(y is None):
#         gr = (1+np.sqrt(5))/2
#         y = x/gr
#
#     base_map = load_map_data()
#
#     non_contig_states = ['AS', 'MP', 'AK', 'HI', 'PR', 'GU', 'VI']
#
#     base_map = base_map.loc[~(base_map.loc[:, 'STUSPS'].isin(non_contig_states)), :]
#
#     print(len(list(base_map.loc[:, 'STUSPS'].unique())))
#
#     fig, ax = plt.subplots(1,1, figsize=(x,y))
#
#     base_map.boundary.plot(ax=ax, edgecolor='black', lw=0.2)
#
#     group_1_map = base_map.loc[(base_map.loc[:, 'STUSPS'].isin(group_1)), :]
#     group_2_map = base_map.loc[(base_map.loc[:, 'STUSPS'].isin(group_2)), :]
#
#     group_1_color = 'xkcd:blue'
#     group_2_color = 'xkcd:red'
#
#     group_1_plot = group_1_map.plot(ax=ax, color=group_1_color, lw=0, label='Treated States')
#     group_2_plot = group_2_map.plot(ax=ax, color=group_2_color, lw=0, label = 'Control States')
#
#
#     plt.xticks([], [])
#     plt.yticks([], [])
#     plt.title(f'Treated States and {num_control} Control States\nwith {title_large_val} Set of Demographic Comparisons')
#
#     xlim = plt.xlim()
#     xlim_diff = max(xlim)-min(xlim)
#
#     ylim = plt.ylim()
#     ylim_diff = max(ylim) - min(ylim)
#
#     plt.text(min(xlim)+(xlim_diff*.01), min(ylim)+(ylim_diff*.11), 'Treated States are Blue', color=group_1_color)
#     plt.text(min(xlim) + (xlim_diff * .01), min(ylim) + (ylim_diff * .01), 'Control States are Red', color=group_2_color)
#
#
#     plt.savefig(f'Maps/{save_large_val}.png', bbox_inches='tight', dpi=300)
#     plt.show()

def get_clean_name(raw_name):
    """Return the LaTeX table label for a raw column name.

    Unknown names are returned unchanged.
    """
    # Raw strings keep the LaTeX backslashes literal. The previous '\#' literals relied
    # on Python passing through an invalid escape sequence, which emits a
    # SyntaxWarning (DeprecationWarning before 3.12). The runtime values are identical.
    name_dict = {
        'NUMADLT': r'\# of Adults in Household',
        'total': r'\hline Sum Total',
        'HHSIZE': r'\# of People in Household',
        'Income': 'Household Income',
        'DRVRCNT': r'\# of Drivers in Household',
        'HBPPOPDN': 'Local Pop. Density',
        'WRKCOUNT': r'\# of Workers in Household',
        'YOUNGCHILD': r'\# of Young Children in Household',
        'Age': 'Mean Household Age',
    }
    return name_dict.get(raw_name, raw_name)

def build_table(res_dict, mean_digits=2, std_digits=5, large=True):
    """Write a LaTeX table of weighted means for treated vs. chosen control states.

    Each row holds a compared variable. The cells contain the 'WTHHFIN'-weighted
    mean, with the standard error of the mean in parentheses on a second line.

    Args:
        res_dict: Result of ``get_best_comparison_group``; its 'data' and
            'best_res' entries are used.
        mean_digits: Decimal places for the weighted means.
        std_digits: Decimal places for the standard errors.
        large: Selects the output file, Tables/large.tex or Tables/small.tex.

    Returns:
        The table as a DataFrame with columns 'Variable', 'Treated States' and
        'Control States'.

    Raises:
        ValueError: If ``res_dict`` holds no selected comparison group.
    """
    from pathlib import Path

    save_large_val = 'large' if large else 'small'

    best_res = res_dict['best_res']
    if best_res is None:
        raise ValueError("res_dict['best_res'] is None: no comparison group was selected")

    data = res_dict['data']
    # Derive the treated states from the same cleaned data the comparison group was
    # optimised on. Previously the full default-year (2017) household file was
    # reloaded just for this, which was slow and inconsistent for other years.
    treat_group = get_states_by_policy_status(data)['With Safety Inspections']
    control_group = best_res['grouping']

    treat_data = data.loc[data.loc[:, 'State Code'].isin(treat_group), :]
    control_data = data.loc[data.loc[:, 'State Code'].isin(control_group), :]

    relevant_dict = best_res['res']['full_res']
    print(relevant_dict)

    weight_col = 'WTHHFIN'
    rows = []
    for key, val in relevant_dict.items():
        if key == 'total':
            continue
        clean_name = get_clean_name(key)
        treat_stats = DescrStatsW(list(treat_data.loc[:, key]), list(treat_data.loc[:, weight_col]))
        control_stats = DescrStatsW(list(control_data.loc[:, key]), list(control_data.loc[:, weight_col]))

        # The LaTeX line break inside \makecell puts the standard error under the mean.
        treat_entry = f'{np.round(treat_stats.mean, mean_digits)} \\\\ ({np.round(treat_stats.std_mean, std_digits)})'
        control_entry = f'{np.round(control_stats.mean, mean_digits)} \\\\ ({np.round(control_stats.std_mean, std_digits)})'

        treat_entry = '\\makecell{' + treat_entry + '}'
        control_entry = '\\makecell{' + control_entry + '}'

        # Scaled squared group difference from the optimisation; printed for reference only.
        clean_val = np.format_float_scientific(val, precision=3, trim='0')
        print(clean_name, clean_val)
        rows.append([clean_name, treat_entry, control_entry])

    df = p.DataFrame(rows, columns=['Variable', 'Treated States', 'Control States'])
    # to_latex does not create missing directories.
    Path('Tables').mkdir(exist_ok=True)
    df.to_latex(f'Tables/{save_large_val}.tex', index=False, escape=False)
    return df

def remove_bad_entries(data, columns):
    """Keep only rows whose value is >= 0 in every column of ``columns``.

    Rows with NaN in any of those columns are dropped too, because the
    comparison with 0 is False for NaN. With no columns, ``data`` is
    returned as-is.
    """
    kept = data
    for name in columns:
        kept = kept[kept[name] >= 0]
    return kept

if __name__ == '__main__':
    hh = load_household_data()

    small_set = ['HHSIZE', 'DRVRCNT', 'Income', 'HBPPOPDN']
    large_set = small_set + ['WRKCOUNT', 'NUMADLT', 'YOUNGCHILD', 'Age']

    res_small = get_best_comparison_group(hh, small_set)
    res_large = get_best_comparison_group(hh, large_set)

    # plot_comparison_group (and the geopandas import it needs) is commented out above,
    # so calling it here raised NameError before any table was written. Re-enable these
    # calls together with that function.
    # plot_comparison_group(res_small['best_res']['grouping'], large=False)
    # plot_comparison_group(res_large['best_res']['grouping'], large=True)
    small_table = build_table(res_small, large=False)
    large_table = build_table(res_large, large=True)